{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Online processing of volumetric data\n",
    "This is a simple demo on simulated toy 3D data that runs motion correction, source extraction and deconvolution, comparing CaImAn batch with CaImAn online (OnACID)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Turn on autoreload so edits to imported modules are picked up without a kernel restart.\n",
    "# get_ipython() only exists inside IPython/Jupyter; under plain Python the NameError is ignored.\n",
    "try:\n",
    "    get_ipython().run_line_magic('load_ext', 'autoreload')\n",
    "    get_ipython().run_line_magic('autoreload', '2')\n",
    "except NameError:\n",
    "    pass\n",
    "\n",
    "import logging\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "from scipy.stats.qmc import Halton\n",
    "\n",
    "import caiman as cm\n",
    "from caiman.utils.visualization import nb_view_patches3d\n",
    "import caiman.source_extraction.cnmf as cnmf\n",
    "from caiman.source_extraction.cnmf.utilities import gaussian_filter\n",
    "\n",
    "import bokeh.plotting as bpl\n",
    "bpl.output_notebook()\n",
    "\n",
    "logfile = None # Replace with a path if you want to log to a file\n",
    "logger = logging.getLogger('caiman')\n",
    "# Set to logging.INFO for considerably more verbose output\n",
    "logger.setLevel(logging.WARNING)\n",
    "logfmt = logging.Formatter('%(relativeCreated)12d [%(filename)s:%(funcName)20s():%(lineno)s] [%(process)d] %(message)s')\n",
    "if 'handler' in globals():\n",
    "    # The cell is being re-run: detach the handler installed last time, otherwise\n",
    "    # every log record would be emitted once per run of this cell.\n",
    "    logger.removeHandler(handler)\n",
    "    handler.close()\n",
    "if logfile is not None:\n",
    "    handler = logging.FileHandler(logfile)\n",
    "else:\n",
    "    handler = logging.StreamHandler()\n",
    "handler.setFormatter(logfmt)\n",
    "logger.addHandler(handler)\n",
    "\n",
    "# Set play_movies to False if you want to disable play of movies, e.g. for remote-hosted Jupyter environments\n",
    "play_movies = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define a function to create some toy data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def gen_data(p=1, noise=.05, T=500, framerate=30, firerate=2., motion=True, init_batch=200):\n",
    "    \"\"\"Simulate a toy 3D calcium-imaging movie together with its ground truth.\n",
    "\n",
    "    N = 20 Gaussian-shaped neurons are placed in a (70, 50, 10) volume, driven by\n",
    "    random spikes filtered through an AR(p) process, and added to a constant\n",
    "    background plus white noise. The global NumPy RNG is seeded with 42 so the\n",
    "    random draws are reproducible.\n",
    "\n",
    "    Args:\n",
    "        p: order of the autoregressive calcium model; must be 1 or 2.\n",
    "        noise: standard deviation of the Gaussian noise added to every voxel.\n",
    "        T: number of frames.\n",
    "        framerate: frames per time unit; firerate / framerate is the per-frame\n",
    "            spike probability.\n",
    "        firerate: expected number of spikes per time unit.\n",
    "        motion: if True, every frame is displaced by smoothed random shifts.\n",
    "        init_batch: number of leading frames in which the second half of the\n",
    "            neurons is kept silent.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (Y, C, S, A, b, f, centers, dims, shifts):\n",
    "            Y: movie of shape (T,) + dims.\n",
    "            C: float32 array (N, T) of calcium traces.\n",
    "            S: boolean array (N, T) of spikes.\n",
    "            A: array (prod(dims), N) of spatial footprints, flattened in Fortran\n",
    "                order. Footprints are normalised to unit L2 norm on the padded\n",
    "                volume, i.e. before the motion margin is cropped away.\n",
    "            b: float32 array (prod(dims),) holding the constant background image.\n",
    "            f: float32 array (T,) of ones, the background time course.\n",
    "            centers: integer array (N, 3) of neuron centers.\n",
    "            dims: the volume size (70, 50, 10).\n",
    "            shifts: array (T, 3) of the simulated shifts, or None if motion is False.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if p is neither 1 nor 2.\n",
    "    \"\"\"\n",
    "    if p == 2:\n",
    "        gamma = np.array([1.5, -.55])\n",
    "    elif p == 1:\n",
    "        gamma = np.array([.9])\n",
    "    else:\n",
    "        # a bare `raise` has no active exception to re-raise; report the real problem instead\n",
    "        raise ValueError(f'p must be 1 or 2, got {p!r}')\n",
    "    dims = (70, 50, 10)  # size of image\n",
    "    sig = (4, 4, 2)      # neurons size\n",
    "    bkgrd = 1.           # background magnitude\n",
    "    N = 20               # number of neurons\n",
    "    np.random.seed(42)\n",
    "    # quasi-random (Halton) centers, kept at least `sig` away from every border\n",
    "    centers = np.round(np.array(sig) + (np.array(dims)-2*np.array(sig)) *\n",
    "                       Halton(d=3, scramble=False).random(n=N)).astype(int)\n",
    "\n",
    "    S = np.random.rand(N, T) < firerate / framerate\n",
    "    S[:, 0] = 0\n",
    "    S[N//2:,:init_batch] = 0 # half of the neurons aren't active in the initial batch\n",
    "    C = S.astype(np.float32)\n",
    "    for i in range(2, T):\n",
    "        if p == 2:\n",
    "            C[:, i] += gamma[0] * C[:, i - 1] + gamma[1] * C[:, i - 2]\n",
    "        else:\n",
    "            C[:, i] += gamma[0] * C[:, i - 1]\n",
    "\n",
    "    if motion:\n",
    "        sig_m = np.array(sig)\n",
    "        # white noise smoothed with an 11-tap moving average, scaled per axis by sig\n",
    "        shifts = -np.transpose([np.convolve(np.random.randn(T-10), np.ones(11)/11*s) for s in sig_m])\n",
    "    else:\n",
    "        sig_m = np.zeros(3, dtype=int)\n",
    "        shifts = None\n",
    "\n",
    "    # simulate on a volume padded by 2*sig_m on every side so shifted frames can be cropped cleanly\n",
    "    pad = 2 * sig_m\n",
    "    padded_dims = tuple(np.array(dims) + 2 * pad)\n",
    "    A = np.zeros(padded_dims + (N,), dtype='float32')\n",
    "    for i in range(N):\n",
    "        A[tuple(centers[i] + pad) + (i,)] = 1.\n",
    "    A = gaussian_filter(A, sig + (0,), truncate=1.5)\n",
    "    A /= np.sqrt(np.sum(np.sum(np.sum(A**2,0),0),0))\n",
    "    f = np.ones(T, dtype='float32')\n",
    "    b = bkgrd * np.ones(A.shape[:-1], dtype='float32')\n",
    "\n",
    "    Yr = np.outer(b.reshape(-1, order='F'), f) + A.reshape((-1, N), order='F').dot(C)\n",
    "    Yr += noise * np.random.randn(*Yr.shape)\n",
    "    Y = Yr.T.reshape((-1,) + padded_dims, order='F').astype(np.float32)\n",
    "    if motion:\n",
    "        Y = np.array([cm.motion_correction.apply_shifts_dft(img, (sh[0], sh[1], sh[2]), 0,\n",
    "                                                            is_freq=False, border_nan='copy')\n",
    "                           for img, sh in zip(Y, -shifts)])\n",
    "        # drop the padding again so movie and ground truth have the nominal dims\n",
    "        crop = tuple(slice(m, -m) for m in pad)\n",
    "        Y = Y[(slice(None),) + crop]\n",
    "        A = A[crop]\n",
    "        b = b[crop]\n",
    "    return Y, C, S, A.reshape((-1, N), order='F'), b.reshape(-1, order='F'), f, centers, dims, shifts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Select file(s) to be processed\n",
    "Create a file with a toy 3d dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Simulate the toy volume (movie Y plus its ground truth) with the default settings\n",
    "Y, C, S, A, b, f, centers, dims, shifts = gen_data()\n",
    "N, T = C.shape  # number of neurons, number of frames\n",
    "\n",
    "# Write the movie into CaImAn's example_movies folder; the fits below read it from this file\n",
    "fname = os.path.join(cm.paths.caiman_datadir(), 'example_movies', 'demoMovie3D.nwb')\n",
    "cm.movie(Y).save(fname)\n",
    "print(fname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Sanity check of the simulation: spatial overlap of the true footprints and\n",
    "# temporal correlation of the true traces (each matrix is computed once and reused).\n",
    "overlap_A = A.T.dot(A)\n",
    "corr_C = np.corrcoef(C)\n",
    "\n",
    "plt.figure(figsize=(9,3))\n",
    "\n",
    "plt.subplot(121)\n",
    "plt.colorbar(plt.imshow(overlap_A))\n",
    "plt.title('overlap of A')\n",
    "\n",
    "plt.subplot(122)\n",
    "plt.colorbar(plt.imshow(corr_C))\n",
    "plt.title('correlation of C')\n",
    "\n",
    "# cell output: largest overlap / correlation after subtracting the identity from each matrix\n",
    "np.max(overlap_A - np.eye(N)), np.max(corr_C - np.eye(N))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inspect the data\n",
    "First, view a max-projection of the correlation image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Reload the movie from disk and compute its local-correlation volume\n",
    "Y = cm.load(fname)\n",
    "Cn = cm.local_correlations(Y, swap_dim=False)\n",
    "\n",
    "# Canvas size: the three projections are laid out side by side with a 20% margin\n",
    "d1, d2, d3 = dims\n",
    "x, y = (int(1.2 * (d1 + d3)), int(1.2 * (d2 + d3)))\n",
    "scale = 6/x\n",
    "fig = plt.figure(figsize=(scale*x, scale*y))\n",
    "\n",
    "# top right: projection along z\n",
    "axz = fig.add_axes([1-d1/x, 1-d2/y, d1/x, d2/y])\n",
    "axz.imshow(Cn.max(2).T, cmap='gray')\n",
    "axz.set_title('Max.proj. z')\n",
    "axz.set_xlabel('x')\n",
    "axz.set_ylabel('y')\n",
    "\n",
    "# top left: projection along x\n",
    "axy = fig.add_axes([0, 1-d2/y, d3/x, d2/y])\n",
    "axy.imshow(Cn.max(0), cmap='gray')\n",
    "axy.set_title('Max.proj. x')\n",
    "axy.set_xlabel('z')\n",
    "axy.set_ylabel('y')\n",
    "\n",
    "# bottom right: projection along y\n",
    "axx = fig.add_axes([1-d1/x, 0, d1/x, d3/y])\n",
    "axx.imshow(Cn.max(1).T, cmap='gray')\n",
    "axx.set_title('Max.proj. y')\n",
    "axx.set_xlabel('x')\n",
    "axx.set_ylabel('z')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Play the movie (optional)\n",
    "This requires loading the movie into memory, which in general is not needed by the pipeline. Displaying the movie uses the OpenCV library. Press `q` to close the video panel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Play only the z-plane with index 5 (last axis) of the loaded volume\n",
    "if play_movies:\n",
    "    Y[...,5].play(magnification=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Settings used for the batch fit and, after K is lowered further down, for the online fit.\n",
    "# 'fr', 'gSig', 'p' and 'init_batch' repeat the values gen_data() used to simulate the movie.\n",
    "params_dict = {'fnames': fname,               # filename(s) to be processed\n",
    "               'fr': 30,                      # frame rate (Hz)\n",
    "               'K': N,                        # (upper bound on) number of components\n",
    "               'is3D': True,                  # flag for volumetric data\n",
    "               'decay_time': 1,               # length of typical transient in seconds\n",
    "               'gSig': (4, 4, 2),             # gaussian width of a 3D gaussian kernel, which approximates a neuron\n",
    "               'p': 1,                        # order of the autoregressive system\n",
    "               'nb': 1,                       # number of background components\n",
    "               'only_init': False,            # whether to run only the initialization\n",
    "               'normalize_init': False,       # whether to equalize the movies during initialization\n",
    "               'motion_correct': True,        # flag for performing motion correction\n",
    "               'max_shifts': (4, 4, 2),       # maximum allowed rigid shifts (in pixels)\n",
    "               'nonneg_movie': False,         # flag for producing a non-negative movie\n",
    "               'init_batch': 200,             # length of mini batch for initialization\n",
    "               'init_method': 'cnmf',         # initialization method for initial batch\n",
    "               'batch_update_suff_stat': True,# flag for updating sufficient statistics (used for updating shapes)\n",
    "               'thresh_overlap': 0,           # space overlap threshold for screening new components\n",
    "              }\n",
    "opts = cnmf.params.CNMFParams(params_dict=params_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run batch version for comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "# Start a cluster for parallel processing. If an earlier run of this cell left a\n",
    "# `dview` behind, that cluster is stopped first so a fresh session is opened.\n",
    "# (%%capture must stay on the first line of the cell; it silences the cell's output.)\n",
    "if 'dview' in locals():\n",
    "    cm.stop_server(dview=dview)\n",
    "c, dview, n_processes = cm.cluster.setup_cluster(\n",
    "    backend='multiprocessing', n_processes=None, single_thread=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# %% fit with batch object\n",
    "# The batch object runs on the cluster started above; fit_file is asked to include motion correction.\n",
    "cnmB = cnmf.CNMF(n_processes=n_processes, params=opts, dview=dview)\n",
    "cnmB.fit_file(motion_correct=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Stop the cluster that was started for the batch fit\n",
    "cm.stop_server(dview=dview)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### View the results\n",
    "View components per plane"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Notebook viewer for the batch components (max image, axis=2); the trailing ';' hides the return value\n",
    "cnmB.estimates.nb_view_components_3d(image_type='max', dims=dims, axis=2, cmap='viridis');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compare with ground truth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def plot_A(cnm):\n",
    "    \"\"\"Compare the spatial footprints inferred by `cnm` with the ground truth.\n",
    "\n",
    "    Shows max projections of the inferred footprints, the true footprints and the\n",
    "    movie, then the overlap matrix between true footprints and their best-matching\n",
    "    inferred components, and prints the min/mean/max of that matrix's diagonal.\n",
    "    Reads the notebook globals A, N, dims and Y.\n",
    "\n",
    "    Args:\n",
    "        cnm: fitted object exposing `estimates.A` (sparse matrix or dense array\n",
    "            with one column per component).\n",
    "    \"\"\"\n",
    "    # estimates.A may be sparse or dense: densify it once, explicitly, so that every\n",
    "    # step below handles both cases (previously only the first imshow had a fallback).\n",
    "    A_est = cnm.estimates.A\n",
    "    A_est = A_est.toarray() if hasattr(A_est, 'toarray') else np.asarray(A_est)\n",
    "    # for every true neuron, index of the inferred component it correlates with most\n",
    "    order = list(map(np.argmax, np.corrcoef(A.T, A_est.T)[:N,N:]))\n",
    "    plt.subplot(131)\n",
    "    plt.imshow(A_est.T.reshape((-1,)+dims, order='F').max(0).max(-1))\n",
    "    plt.title('inferred A')\n",
    "    plt.subplot(132)\n",
    "    plt.imshow(A.T.reshape((-1,)+dims, order='F').max(0).max(-1))\n",
    "    plt.title('true A')\n",
    "    plt.subplot(133)\n",
    "    plt.imshow(Y.max(0).max(-1))\n",
    "    plt.title('max Y projection');\n",
    "\n",
    "    plt.figure(figsize=(5,3))\n",
    "    overlap = A_est.T[order].dot(A)\n",
    "    plt.colorbar(plt.imshow(overlap))\n",
    "    plt.title('overlap')\n",
    "    plt.show()\n",
    "    overlap = overlap.diagonal()\n",
    "    print(f'Overlap of neural shapes   Min: {overlap.min():.4f},  Mean: {overlap.mean():.4f},  Max: {overlap.max():.4f}')\n",
    "    \n",
    "plot_A(cnmB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def plot_C(cnm):\n",
    "    \"\"\"Compare the calcium traces inferred by `cnm` with the ground truth.\n",
    "\n",
    "    Plots the matched inferred traces above the true ones, then the correlation\n",
    "    matrix between them, and prints the min/mean/max of its diagonal.\n",
    "    Reads the notebook globals C and N.\n",
    "\n",
    "    Args:\n",
    "        cnm: fitted object exposing `estimates.C` with one row per component.\n",
    "    \"\"\"\n",
    "    import warnings\n",
    "\n",
    "    # for every true neuron, index of the inferred trace it correlates with most\n",
    "    order = list(map(np.argmax, np.corrcoef(C, cnm.estimates.C)[:N,N:]))\n",
    "    # The matching is only one-to-one if no index repeats. The former check compared\n",
    "    # against tuple(order), which keeps duplicates and therefore could never fire.\n",
    "    # Warn instead of raising so the comparison plots are still produced.\n",
    "    if len(order) != len(set(order)):\n",
    "        warnings.warn('Some inferred components are the best match for more than one true neuron')\n",
    "\n",
    "    plt.figure(figsize=(12,5))\n",
    "    plt.subplot(211)\n",
    "    plt.plot(cnm.estimates.C[order].T)\n",
    "    plt.title('inferred C')\n",
    "    plt.subplot(212)\n",
    "    plt.plot(C.T)\n",
    "    plt.title('true C')\n",
    "\n",
    "    plt.figure(figsize=(5,3))\n",
    "    corr = np.corrcoef(C, cnm.estimates.C[order])[:N,N:]\n",
    "    plt.colorbar(plt.imshow(corr))\n",
    "    plt.title('correlation')\n",
    "    plt.show()\n",
    "    corr = corr.diagonal()\n",
    "    print(f'Correlation of (denoised) fluor. C   Min: {corr.min():.4f},  Mean: {corr.mean():.4f},  Max: {corr.max():.4f}')\n",
    "\n",
    "plot_C(cnmB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def plot_shifts(cnm):\n",
    "    \"\"\"Plot the shifts estimated by `cnm` above the simulated (true) shifts.\n",
    "\n",
    "    Also prints, for each of the three axes, the correlation between the true and\n",
    "    the estimated shifts. Reads the notebook globals `shifts` and `T`.\n",
    "\n",
    "    Args:\n",
    "        cnm: fitted object exposing `params.motion['pw_rigid']` and `estimates.shifts`.\n",
    "    \"\"\"\n",
    "    plt.figure(figsize=(12,5))\n",
    "    plt.subplot(211)\n",
    "    if cnm.params.motion['pw_rigid']:\n",
    "        raw_shifts = cnm.estimates.shifts\n",
    "        # NOTE(review): presumably the piecewise-rigid shifts arrive either frame-first\n",
    "        # (length T) or axis-first; the latter is transposed to frame-first -- confirm layout\n",
    "        if len(raw_shifts) == T:\n",
    "            est_shifts = np.array(raw_shifts)\n",
    "        else:\n",
    "            est_shifts = np.transpose(raw_shifts, (1,2,0))\n",
    "        plt.plot(est_shifts[:,0])\n",
    "        est_for_corr = est_shifts.T[:,0]\n",
    "    else:\n",
    "        plt.plot(cnm.estimates.shifts)\n",
    "        est_for_corr = np.transpose(cnm.estimates.shifts)\n",
    "    # rows 0-2 are the true axes and columns 3-5 the estimated ones, so the diagonal pairs equal axes\n",
    "    shift_corr = np.corrcoef(np.transpose(shifts), est_for_corr)[:3,3:].diagonal()\n",
    "    print('Correlation with true shifts  ', shift_corr)\n",
    "    plt.title('inferred shifts')\n",
    "    plt.ylabel('pixels')\n",
    "\n",
    "    plt.subplot(212)\n",
    "    true_shifts = np.array(shifts)\n",
    "    for k, axis_name in enumerate(('x','y','z')):\n",
    "        plt.plot(true_shifts[:,k], label=axis_name)\n",
    "    plt.legend()\n",
    "    plt.title('true shifts')\n",
    "    plt.xlabel('frames')\n",
    "    plt.ylabel('pixels')\n",
    "    \n",
    "plot_shifts(cnmB)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run online version"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# only half of the neurons are active in the initial batch\n",
    "# NOTE: this modifies params_dict in place; re-run the 'Set parameters' cell to get K = N back\n",
    "params_dict['K'] = N//2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# %% fit with online object\n",
    "# Rebuild the options from params_dict so its current K is the one used by the online fit\n",
    "opts = cnmf.params.CNMFParams(params_dict=params_dict)\n",
    "cnmO = cnmf.online_cnmf.OnACID(params=opts)\n",
    "cnmO.fit_online();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### View the results\n",
    "View components per plane"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Same viewer call as for the batch result, now on the online (OnACID) components\n",
    "cnmO.estimates.nb_view_components_3d(image_type='max', dims=dims, axis=2, cmap='viridis');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compare with ground truth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plot_A(cnmO)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plot_C(cnmO)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plot_shifts(cnmO)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
